import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Iterator, List, Optional

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters
from dbgpt.util.parameter_utils import ParameterDescription


@dataclass
class WorkerRunData:
    """Runtime record for a single model worker instance.

    Groups the worker object with its address (``host``/``port``), its
    parameter objects, the event used to flag that it has stopped, and
    optional bookkeeping fields that default to ``None``.
    """

    host: str
    port: int
    worker_key: str
    worker: ModelWorker
    worker_params: ModelWorkerParameters
    model_params: ModelParameters
    stop_event: asyncio.Event
    # The fields below are optional and stay ``None`` until a caller sets them.
    semaphore: Optional[asyncio.Semaphore] = None
    command_args: Optional[List[str]] = None
    _heartbeat_future: Optional[Future] = None
    _last_heartbeat: Optional[datetime] = None

    def _to_print_key(self) -> str:
        """Return a human-readable label for this worker.

        Returns:
            str: ``model <model_name>@<model_type>(<host>:<port>)``.
        """
        model_name = self.model_params.model_name
        # Parameter objects that carry no ``model_type`` attribute are
        # labelled "text2vec".
        model_type = getattr(self.model_params, "model_type", "text2vec")
        return f"model {model_name}@{model_type}({self.host}:{self.port})"

    @property
    def stopped(self) -> bool:
        """Whether ``stop_event`` has been set."""
        return self.stop_event.is_set()


class WorkerManager(ABC):
    """Abstract interface for managing model workers.

    Every method here is abstract; concrete subclasses supply the behaviour.
    The interface covers lifecycle (``start``/``stop``), instance lookup and
    selection, model startup/shutdown, and inference-style calls (generation,
    embeddings, token counting, metadata). Several operations are offered in
    both an ``async`` form and a ``sync_``-prefixed synchronous form.
    """

    @abstractmethod
    async def start(self):
        """Start the worker manager.

        Raises:
            Exception: if the worker manager does not start successfully.
        """

    @abstractmethod
    async def stop(self, ignore_exception: bool = False):
        """Stop the worker manager.

        Args:
            ignore_exception (bool, optional): Defaults to False. Presumably
                controls whether errors raised while stopping are suppressed;
                confirm against the concrete implementation.
        """

    @abstractmethod
    def after_start(self, listener: Callable[["WorkerManager"], None]):
        """Add a listener to run after the WorkerManager has started.

        Args:
            listener (Callable[[WorkerManager], None]): callback that takes a
                WorkerManager and returns nothing.
        """

    @abstractmethod
    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Asynchronously get model instances by worker type and model name.

        Args:
            worker_type (str): worker type
            model_name (str): model name
            healthy_only (bool, optional): only return healthy instances.
                Defaults to True.

        Returns:
            List[WorkerRunData]: worker run data list
        """

    @abstractmethod
    async def get_all_model_instances(
        self, worker_type: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Asynchronously get all model instances of a worker type.

        Args:
            worker_type (str): worker type
            healthy_only (bool, optional): only return healthy instances.
                Defaults to True.

        Returns:
            List[WorkerRunData]: worker run data list
        """

    @abstractmethod
    def sync_get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Get model instances by worker type and model name (synchronous).

        Takes the same parameters as ``get_model_instances``.

        Returns:
            List[WorkerRunData]: worker run data list
        """

    @abstractmethod
    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Asynchronously select one instance.

        Args:
            worker_type (str): worker type
            model_name (str): model name
            healthy_only (bool, optional): Defaults to True. Presumably limits
                the candidates to healthy instances; confirm against the
                implementation.

        Returns:
            WorkerRunData: the selected instance. How one is chosen is left
            to the implementation.
        """

    @abstractmethod
    def sync_select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Select one instance (synchronous).

        Takes the same parameters as ``select_one_instance``.

        Returns:
            WorkerRunData: the selected instance
        """

    @abstractmethod
    async def supported_models(self) -> List[WorkerSupportedModel]:
        """List supported models.

        Returns:
            List[WorkerSupportedModel]: supported model descriptions
        """

    @abstractmethod
    async def model_startup(self, startup_req: WorkerStartupRequest):
        """Create and start a model instance.

        Args:
            startup_req (WorkerStartupRequest): the startup request
        """

    @abstractmethod
    async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
        """Shut down a model instance.

        Args:
            shutdown_req (WorkerStartupRequest): the shutdown request; note
                that it uses the same request type as ``model_startup``.
        """

    @abstractmethod
    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
        """Generate a streamed result (chat scene).

        NOTE(review): the method is ``async def`` but annotated as returning
        ``Iterator``; if implementations are async generators the annotation
        should probably be ``AsyncIterator`` -- confirm.

        Args:
            params (Dict): generation parameters
            **kwargs: extra keyword arguments; their meaning is defined by the
                implementation.
        """

    @abstractmethod
    async def generate(self, params: Dict) -> ModelOutput:
        """Generate a non-streamed result.

        Args:
            params (Dict): generation parameters

        Returns:
            ModelOutput: the model output
        """

    @abstractmethod
    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Asynchronously embed input.

        Args:
            params (Dict): embedding parameters

        Returns:
            List[List[float]]: a list of float vectors
        """

    @abstractmethod
    def sync_embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input (synchronous).

        This function may be passed to a third-party system that makes
        synchronous calls, so a synchronous version must be provided.

        Args:
            params (Dict): embedding parameters

        Returns:
            List[List[float]]: a list of float vectors
        """

    @abstractmethod
    async def count_token(self, params: Dict) -> int:
        """Count the tokens of a prompt.

        Args:
            params (Dict): parameters, e.g.
                {"prompt": "hello", "model": "vicuna-13b-v1.5"}

        Returns:
            int: token count
        """

    @abstractmethod
    async def get_model_metadata(self, params: Dict) -> ModelMetadata:
        """Get model metadata.

        Args:
            params (Dict): parameters, e.g. {"model": "vicuna-13b-v1.5"}

        Returns:
            ModelMetadata: metadata of the model
        """

    @abstractmethod
    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Apply a request to workers.

        Args:
            apply_req (WorkerApplyRequest): the apply request

        Returns:
            WorkerApplyOutput: result of the apply operation
        """

    @abstractmethod
    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Get the parameter descriptions of a model.

        Args:
            worker_type (str): worker type
            model_name (str): model name

        Returns:
            List[ParameterDescription]: parameter descriptions
        """


class WorkerManagerFactory(BaseComponent, ABC):
    """Abstract component that creates ``WorkerManager`` objects."""

    # Component name: the ``.value`` of ``ComponentType.WORKER_MANAGER_FACTORY``.
    name = ComponentType.WORKER_MANAGER_FACTORY.value

    def init_app(self, system_app: SystemApp):
        """Accept the system app; this base class does nothing with it."""

    @abstractmethod
    def create(self) -> WorkerManager:
        """Create and return a worker manager."""
